/**
* Copyright (C) 2014 zml (netevents@zachsthings.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.zachsthings.netevents.sec;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.security.AlgorithmParameters;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.KeySpec;
/**
 * Socket wrapper that provides an encrypted channel.
 * <p>
 * Traffic is encrypted with AES/CBC/PKCS5Padding. The 128-bit key is derived from the passphrase
 * with PBKDF2WithHmacSHA1 (fixed salt, 1024 iterations), so both peers must be configured with the
 * same passphrase.
 * <p>
 * NOTE(review): the salt and the all-zero IV are constants. Every {@code write} is encrypted as an
 * independent message starting from that same IV, so identical plaintexts produce identical
 * ciphertexts, and no message authentication is performed. Changing any of this alters the wire
 * format and must be rolled out to both peers at once.
 */
public class AESSocketWrapper implements SocketWrapper {
    private static final byte[] SALT = new byte[]{8, 12, 16, 84, 98, 93, 92, 23, 38, 3};
    private static final int ITER_COUNT = 1024, KEY_LEN = 128;
    /** AES block size in bytes; also the IV length. */
    private static final int BLOCK_SIZE = 16;

    private final String passphrase;

    /**
     * @param passphrase shared secret from which both peers derive the AES key
     */
    public AESSocketWrapper(String passphrase) {
        this.passphrase = passphrase;
    }

    /**
     * Wraps {@code chan} in a channel that encrypts everything written and decrypts everything read.
     *
     * @param chan the underlying connected socket channel
     * @return the encrypting wrapper around {@code chan}
     * @throws IOException if the JCE lacks a required algorithm or rejects the derived key or IV
     */
    @Override
    public SocketChannel wrapSocket(SocketChannel chan) throws IOException {
        PBEKeySpec spec = new PBEKeySpec(passphrase.toCharArray(), SALT, ITER_COUNT, KEY_LEN);
        try {
            SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1");
            SecretKey secretKey = factory.generateSecret(spec);
            Key key = new SecretKeySpec(secretKey.getEncoded(), "AES");
            AlgorithmParameters params = AlgorithmParameters.getInstance("AES");
            params.init(new IvParameterSpec(new byte[BLOCK_SIZE]));
            return new CryptSocketChannel(chan, key, params);
        } catch (NoSuchAlgorithmException
                | InvalidKeySpecException
                | NoSuchPaddingException
                | InvalidKeyException
                | InvalidAlgorithmParameterException
                | InvalidParameterSpecException e) {
            throw new IOException(e);
        } finally {
            // PBEKeySpec keeps its own copy of the password; wipe it once the key is derived
            spec.clearPassword();
        }
    }

    /**
     * Channel that encrypts each {@code write} call as one self-contained AES message and decrypts
     * whatever a single underlying read returns as one message.
     * <p>
     * NOTE(review): there is no length framing on the wire. Decryption only succeeds when one
     * underlying read returns exactly one whole message written by the peer. If TCP splits or
     * coalesces messages, {@link #read(ByteBuffer)} fails with an IOException.
     */
    private static class CryptSocketChannel extends WrappedSocketChannel {
        /** Guards {@link #dec} and {@link #readTmp}. */
        private final Object readLock = new Object();
        /** Guards {@link #enc} and keeps each encrypted message contiguous on the wire. */
        private final Object writeLock = new Object();
        /**
         * Reusable plaintext buffer. Between calls it is left in read mode (flipped). Any bytes
         * still remaining in it are decrypted data that the caller's buffer had no room for.
         * Only accessed under readLock.
         */
        private ByteBuffer readTmp;
        private final Cipher enc, dec;

        public CryptSocketChannel(SocketChannel wrappee, Key k, AlgorithmParameters params) throws InvalidKeyException, NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException {
            super(wrappee);
            enc = Cipher.getInstance("AES/CBC/PKCS5Padding");
            enc.init(Cipher.ENCRYPT_MODE, k, params);
            dec = Cipher.getInstance("AES/CBC/PKCS5Padding");
            dec.init(Cipher.DECRYPT_MODE, k, params);
        }

        /**
         * Returns readTmp cleared and with room for at least {@code capacity} bytes.
         * <p>
         * The buffer is reallocated when it is too small. It is also reallocated when the request
         * is larger than 16 bytes but the buffer is more than 4x the request, so a single large
         * message does not pin a huge buffer. Only call under readLock, and only when readTmp
         * holds no undelivered bytes.
         *
         * @param capacity required capacity
         * @return adjusted, empty readTmp buffer
         */
        private ByteBuffer adjustReadTmp(int capacity) {
            if (readTmp == null || readTmp.capacity() < capacity
                    || (capacity > 16 && readTmp.capacity() > (4 * capacity))) {
                readTmp = ByteBuffer.allocate(capacity);
            } else {
                readTmp.clear();
            }
            return readTmp;
        }

        /**
         * Reads and decrypts one message from the underlying channel into {@code dst}.
         * <p>
         * Plaintext that does not fit in {@code dst} is kept and returned by the following call(s)
         * before any more ciphertext is read.
         *
         * @return number of plaintext bytes stored in {@code dst}, 0 if none, or -1 at end of stream
         * @throws IOException if the underlying read fails or the received data cannot be decrypted
         */
        @Override
        public int read(ByteBuffer dst) throws IOException {
            if (!dst.hasRemaining()) {
                return 0;
            }
            synchronized (readLock) {
                if (readTmp != null && readTmp.hasRemaining()) {
                    // Deliver leftovers from the previous message before touching the network
                    return transfer(readTmp, dst);
                }
                // Largest ciphertext whose plaintext could fit in the space dst has left
                ByteBuffer src = ByteBuffer.allocate(paddedLen(dst.remaining()));
                int read = super.read(src);
                if (read <= 0) { // Nothing or closed channel
                    return read;
                }
                src.flip();
                ByteBuffer plain = adjustReadTmp(dec.getOutputSize(read));
                boolean decrypted = false;
                try {
                    dec.doFinal(src, plain);
                    decrypted = true;
                } catch (ShortBufferException | IllegalBlockSizeException | BadPaddingException e) {
                    // If one of these happens, it's probably an incorrect passphrase
                    throw new IOException("Invalid data received from remote! Do passphrases match?", e);
                } finally {
                    if (!decrypted) {
                        // Never let a half-written buffer be mistaken for leftover plaintext
                        plain.limit(0);
                    }
                }
                plain.flip();
                return transfer(plain, dst);
            }
        }

        @Override
        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        /**
         * Encrypts all remaining bytes of {@code src} as one message and writes it completely.
         * <p>
         * A partial write would truncate the message and break the peer's decryption, so this
         * keeps writing until the whole ciphertext is sent. NOTE(review): on a non-blocking
         * channel this spins until the socket send buffer accepts everything.
         *
         * @return number of plaintext bytes consumed from {@code src}
         * @throws IOException if encryption or the underlying write fails
         */
        @Override
        public int write(ByteBuffer src) throws IOException {
            int plainLen = src.remaining();
            if (plainLen == 0) {
                return 0;
            }
            synchronized (writeLock) {
                ByteBuffer dst = ByteBuffer.allocate(enc.getOutputSize(plainLen));
                try {
                    enc.doFinal(src, dst);
                } catch (ShortBufferException | IllegalBlockSizeException | BadPaddingException e) {
                    throw new IOException(e);
                }
                dst.flip();
                while (dst.hasRemaining()) {
                    super.write(dst);
                }
            }
            return plainLen;
        }

        @Override
        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        /**
         * Copies as many bytes as fit from {@code from} into {@code to}, advancing both positions.
         *
         * @return number of bytes copied
         */
        private static int transfer(ByteBuffer from, ByteBuffer to) {
            int count = Math.min(from.remaining(), to.remaining());
            ByteBuffer chunk = from.duplicate();
            chunk.limit(chunk.position() + count);
            to.put(chunk);
            from.position(from.position() + count);
            return count;
        }

        /**
         * Returns the largest AES/PKCS5 ciphertext length produced by a plaintext of at most
         * {@code plainLen} bytes. This is the next multiple of the block size strictly above
         * {@code plainLen}, because PKCS5 always adds at least one padding byte.
         */
        private static int paddedLen(int plainLen) {
            return plainLen + (BLOCK_SIZE - plainLen % BLOCK_SIZE);
        }
    }
}